{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example of loading a custom tree model into SHAP\n", "\n", "This notebook shows how to pass a custom tree ensemble model into SHAP for explanation." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import graphviz\n", "import numpy as np\n", "import scipy.special\n", "import sklearn.ensemble\n", "import sklearn.tree\n", "\n", "import shap" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple regression tree model\n", "\n", "Here we define a simple regression tree and then load it into SHAP as a custom model." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
DecisionTreeRegressor(max_depth=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "DecisionTreeRegressor(max_depth=2)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = shap.datasets.adult()\n", "\n", "# a depth-2 tree keeps the structure small enough to read in full\n", "orig_model = sklearn.tree.DecisionTreeRegressor(max_depth=2)\n", "orig_model.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "x\n", "5\n", " ≤ 3.5\n", "squared_error = 0.183\n", "samples = 32561\n", "value = 0.241\n", "\n", "\n", "\n", "1\n", "\n", "x\n", "8\n", " ≤ 7073.5\n", "squared_error = 0.062\n", "samples = 17800\n", "value = 0.066\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "4\n", "\n", "x\n", "2\n", " ≤ 12.5\n", "squared_error = 0.248\n", "samples = 14761\n", "value = 0.451\n", "\n", "\n", "\n", "0->4\n", "\n", "\n", "False\n", "\n", "\n", "\n", "2\n", "\n", "squared_error = 0.047\n", "samples = 17482\n", "value = 0.05\n", "\n", "\n", "\n", "1->2\n", "\n", "\n", "\n", "\n", "\n", "3\n", "\n", "squared_error = 0.036\n", "samples = 318\n", "value = 0.962\n", "\n", "\n", "\n", "1->3\n", "\n", "\n", "\n", "\n", "\n", "5\n", "\n", "squared_error = 0.223\n", "samples = 10329\n", "value = 0.335\n", "\n", "\n", "\n", "4->5\n", "\n", "\n", "\n", "\n", "\n", "6\n", "\n", "squared_error = 0.2\n", "samples = 4432\n", "value = 0.724\n", "\n", "\n", "\n", "4->6\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# render the fitted tree so its splits can be compared with the arrays extracted below\n", "dot_data = sklearn.tree.export_graphviz(orig_model, out_file=None, filled=True, rounded=True, special_characters=True)\n", "graph = graphviz.Source(dot_data)\n", "graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The arrays below are read from the fitted tree's `tree_` attribute. For more information on what exactly these attributes mean, see the [scikit-learn 
documentation](https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#tree-structure)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " children_left [ 1 2 -1 -1 5 -1 -1]\n", " children_right [ 4 3 -1 -1 6 -1 -1]\n", " children_default [ 4 3 -1 -1 6 -1 -1]\n", " features [ 5 8 -2 -2 2 -2 -2]\n", " thresholds [ 3.5000e+00 7.0735e+03 -2.0000e+00 -2.0000e+00 1.2500e+01 -2.0000e+00\n", " -2.0000e+00]\n", " values [[0.241]\n", " [0.066]\n", " [0.05 ]\n", " [0.962]\n", " [0.451]\n", " [0.335]\n", " [0.724]]\n", "node_sample_weight [32561. 17800. 17482. 318. 14761. 10329. 4432.]\n" ] } ], "source": [ "# extract the arrays that define the tree\n", "children_left = orig_model.tree_.children_left\n", "children_right = orig_model.tree_.children_right\n", "# SHAP follows children_default when a feature value is missing (NaN); this example\n", "# assumes X has no missing values, so reusing the right children is safe here\n", "children_default = children_right.copy()\n", "features = orig_model.tree_.feature\n", "thresholds = orig_model.tree_.threshold\n", "values = orig_model.tree_.value.reshape(orig_model.tree_.value.shape[0], 1)  # one output value per node\n", "node_sample_weight = orig_model.tree_.weighted_n_node_samples\n", "\n", "print(\" children_left\", children_left)  # note that negative children values mean this is a leaf node\n", "print(\" children_right\", children_right)\n", "print(\" children_default\", children_default)\n", "print(\" features\", features)\n", "print(\" thresholds\", thresholds.round(3))  # -2 means the node is a leaf node\n", "print(\" values\", values.round(3))\n", "print(\"node_sample_weight\", node_sample_weight)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# define a custom tree model: a dict holding a list of trees, each described by per-node arrays\n", "tree_dict = {\n", "    \"children_left\": children_left,\n", "    \"children_right\": children_right,\n", "    \"children_default\": children_default,\n", "    \"features\": features,\n", "    \"thresholds\": thresholds,\n", "    \"values\": values,\n", "    \"node_sample_weight\": node_sample_weight,\n", "}\n", "model = {\"trees\": [tree_dict]}" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "explainer = shap.TreeExplainer(model)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Make sure that the ingested SHAP model (a TreeEnsemble object) makes the\n", "# same predictions as the original model\n", "assert np.abs(explainer.model.predict(X) - orig_model.predict(X)).max() < 1e-4" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# make sure the SHAP values sum up to the model output (this is the local accuracy property)\n", "assert np.abs(explainer.expected_value + explainer.shap_values(X).sum(1) - orig_model.predict(X)).max() < 1e-4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple GBM classification model (with 2 trees)\n", "\n", "Here we define a simple gradient-boosting classifier and then load it into SHAP as a custom model." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
GradientBoostingClassifier(n_estimators=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "GradientBoostingClassifier(n_estimators=2)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X2, y2 = shap.datasets.adult()\n", "# two boosting stages keep the ensemble small enough to inspect tree by tree\n", "orig_model2 = sklearn.ensemble.GradientBoostingClassifier(n_estimators=2)\n", "orig_model2.fit(X2, y2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pull the info of the first tree" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " children_left1 [ 1 2 3 -1 -1 6 -1 -1 9 10 -1 -1 13 -1 -1]\n", " children_right1 [ 8 5 4 -1 -1 7 -1 -1 12 11 -1 -1 14 -1 -1]\n", " children_default1 [ 8 5 4 -1 -1 7 -1 -1 12 11 -1 -1 14 -1 -1]\n", " features1 [ 5 8 2 -2 -2 0 -2 -2 2 8 -2 -2 8 -2 -2]\n", " thresholds1 [ 3.5000e+00 7.0735e+03 1.2500e+01 -2.0000e+00 -2.0000e+00 2.0500e+01\n", " -2.0000e+00 -2.0000e+00 1.2500e+01 5.0955e+03 -2.0000e+00 -2.0000e+00\n", " 5.0955e+03 -2.0000e+00 -2.0000e+00]\n", " values1 [[-0. 
]\n", " [-0.175]\n", " [-0.191]\n", " [-1.177]\n", " [-0.503]\n", " [ 0.721]\n", " [-0.223]\n", " [ 4.013]\n", " [ 0.211]\n", " [ 0.094]\n", " [ 0.325]\n", " [ 4.048]\n", " [ 0.483]\n", " [ 2.372]\n", " [ 4.128]]\n", "node_sample_weight1 [3.2561e+04 1.7800e+04 1.7482e+04 1.4036e+04 3.4460e+03 3.1800e+02\n", " 5.0000e+00 3.1300e+02 1.4761e+04 1.0329e+04 9.8070e+03 5.2200e+02\n", " 4.4320e+03 3.7540e+03 6.7800e+02]\n" ] } ], "source": [ "# estimators_ has one row per boosting stage; binary classification uses a single\n", "# regression tree per stage, stored in column 0\n", "first_tree = orig_model2.estimators_[0][0].tree_\n", "\n", "# extract the arrays that define the tree\n", "children_left1 = first_tree.children_left\n", "children_right1 = first_tree.children_right\n", "# SHAP follows children_default when a feature value is missing (NaN); this example\n", "# assumes X2 has no missing values, so reusing the right children is safe here\n", "children_default1 = children_right1.copy()\n", "features1 = first_tree.feature\n", "thresholds1 = first_tree.threshold\n", "values1 = first_tree.value.reshape(first_tree.value.shape[0], 1)  # one output value per node\n", "node_sample_weight1 = first_tree.weighted_n_node_samples\n", "\n", "print(\" children_left1\", children_left1)  # note that negative children values mean this is a leaf node\n", "print(\" children_right1\", children_right1)\n", "print(\" children_default1\", children_default1)\n", "print(\" features1\", features1)\n", "print(\" thresholds1\", thresholds1.round(3))\n", "print(\" values1\", values1.round(3))\n", "print(\"node_sample_weight1\", node_sample_weight1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pull the info of the second tree" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " children_left2 [ 1 2 3 -1 -1 6 -1 -1 9 10 -1 -1 13 -1 -1]\n", " children_right2 [ 8 5 4 -1 -1 7 -1 -1 12 11 -1 -1 14 -1 -1]\n", " children_default2 [ 8 5 4 -1 -1 7 -1 -1 12 11 -1 -1 14 -1 -1]\n", " features2 [ 5 8 2 -2 -2 0 -2 -2 2 8 -2 -2 8 -2 -2]\n", " thresholds2 [ 3.5000e+00 7.0735e+03 1.3500e+01 -2.0000e+00 -2.0000e+00 2.0500e+01\n", " -2.0000e+00 -2.0000e+00 1.2500e+01 5.0955e+03 -2.0000e+00 -2.0000e+00\n", " 5.0955e+03 
-2.0000e+00 -2.0000e+00]\n", " values2 [[-1.000e-03]\n", " [-1.580e-01]\n", " [-1.720e-01]\n", " [-1.062e+00]\n", " [ 1.360e-01]\n", " [ 6.420e-01]\n", " [-2.030e-01]\n", " [ 2.993e+00]\n", " [ 1.880e-01]\n", " [ 8.400e-02]\n", " [ 2.870e-01]\n", " [ 3.015e+00]\n", " [ 4.310e-01]\n", " [ 1.895e+00]\n", " [ 3.066e+00]]\n", "node_sample_weight2 [3.2561e+04 1.7800e+04 1.7482e+04 1.6560e+04 9.2200e+02 3.1800e+02\n", " 5.0000e+00 3.1300e+02 1.4761e+04 1.0329e+04 9.8070e+03 5.2200e+02\n", " 4.4320e+03 3.7540e+03 6.7800e+02]\n" ] } ], "source": [ "# second boosting stage (row 1); again a single regression tree in column 0\n", "second_tree = orig_model2.estimators_[1][0].tree_\n", "\n", "# extract the arrays that define the tree\n", "children_left2 = second_tree.children_left\n", "children_right2 = second_tree.children_right\n", "# SHAP follows children_default when a feature value is missing (NaN); this example\n", "# assumes X2 has no missing values, so reusing the right children is safe here\n", "children_default2 = children_right2.copy()\n", "features2 = second_tree.feature\n", "thresholds2 = second_tree.threshold\n", "values2 = second_tree.value.reshape(second_tree.value.shape[0], 1)  # one output value per node\n", "node_sample_weight2 = second_tree.weighted_n_node_samples\n", "\n", "print(\" children_left2\", children_left2)  # note that negative children values mean this is a leaf node\n", "print(\" children_right2\", children_right2)\n", "print(\" children_default2\", children_default2)\n", "print(\" features2\", features2)\n", "print(\" thresholds2\", thresholds2.round(3))\n", "print(\" values2\", values2.round(3))\n", "print(\"node_sample_weight2\", node_sample_weight2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a list of SHAP Trees" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# define a custom tree model; leaf values are multiplied by the learning rate because\n", "# sklearn applies that scaling at prediction time instead of storing it in the trees\n", "tree_dicts = [\n", "    {\n", "        \"children_left\": children_left1,\n", "        \"children_right\": children_right1,\n", "        \"children_default\": children_default1,\n", "        \"features\": features1,\n", "        \"thresholds\": thresholds1,\n", "        \"values\": values1 * orig_model2.learning_rate,\n", "        \"node_sample_weight\": node_sample_weight1,\n", "    },\n", "    {\n", "        \"children_left\": children_left2,\n", "        \"children_right\": children_right2,\n", "        \"children_default\": children_default2,\n", "        \"features\": features2,\n", "        \"thresholds\": thresholds2,\n", "        \"values\": values2 * orig_model2.learning_rate,\n", "        \"node_sample_weight\": node_sample_weight2,\n", "    },\n", "]\n", "model2 = {\n", "    \"trees\": tree_dicts,\n", "    # starting raw score: log-odds of the positive-class prior fitted by the init estimator\n", "    \"base_offset\": scipy.special.logit(orig_model2.init_.class_prior_[1]),\n", "    \"tree_output\": \"log_odds\",\n", "    \"objective\": \"binary_crossentropy\",\n", "    \"input_dtype\": np.float32,  # dtype the model expects for the input feature data\n", "    \"internal_dtype\": np.float64,  # dtype the model uses for values and thresholds\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain the custom model" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# build a background dataset from the 200 people whose predicted probability is closest to 0.95\n", "vs = np.abs(orig_model2.predict_proba(X2)[:, 1] - 0.95)\n", "inds = np.argsort(vs)[:200]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# build an explainer that explains the probability output of the model\n", "explainer2 = shap.TreeExplainer(\n", "    model2,\n", "    X2.iloc[inds, :],\n", "    feature_perturbation=\"interventional\",\n", "    model_output=\"probability\",\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Make sure that the ingested SHAP model (a TreeEnsemble object) makes the\n", "# same predictions as the original model\n", "assert np.abs(explainer2.model.predict(X2, output=\"probability\") - orig_model2.predict_proba(X2)[:, 1]).max() < 1e-4" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# make sure the sum of the SHAP values equals the model output\n", "shap_sum = explainer2.expected_value + explainer2.shap_values(X2).sum(1)\n", "assert np.abs(shap_sum - orig_model2.predict_proba(X2)[:, 1]).max() < 1e-4" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }